import torch
import torch.nn as nn
from tensordict import TensorDict
from torch import Tensor
from functools import partial
from typing import Tuple
from einops import rearrange
from typing import Literal
from model.net.custom import DiscreteEmbeddingLayer, AttentionFusion

class Critic(nn.Module):
    """Critic head: encode input features, optionally fuse an action, decode.

    The forward pass runs ``forward_encoder`` to obtain a latent tensor and
    then ``decoder`` to map that latent (optionally fused with an action
    embedding) through ``critic_blocks``, ``norm_decoder`` and the
    user-supplied ``decoder_layer``.

    Args:
        *args: Accepted and ignored.
        encoder_layer: Container indexed by string key. ``'stem_layer'`` is
            used by both encoder variants; ``'indices'`` is additionally used
            by the ``"transformer"`` variant as a positional-embedding lookup.
        middle_layer: Container indexed by string key. ``'reduce_seq_layer'``
            and ``'reduce_embed_layer'`` are used by the ``"mlp"`` variant
            only.
        decoder_layer: Final module applied by ``decoder``.
        blocks: Feature-extraction module applied after the stem layer.
        norm: Normalisation module applied after ``blocks``.
        action_dim: Passed to ``DiscreteEmbeddingLayer`` as
            ``num_max_tokens``.
        embed_dim: Width used for the action embedding, the attention fusion
            module, ``critic_blocks`` and ``norm_decoder``.
        norm_layer: Factory called as ``norm_layer(embed_dim)`` to build
            ``norm_decoder``.
        method: Encoder variant. Only ``"mlp"`` and ``"transformer"`` are
            implemented; any other value makes ``forward_encoder`` raise
            ``NotImplementedError``.
        device: Device every sub-module is moved to at construction time.
        seq_len: Number of positions for which positional embeddings are
            looked up in the ``"transformer"`` variant.
        **kwargs: Accepted and ignored.
    """

    # Encoder variants that ``forward_encoder`` actually implements.
    _SUPPORTED_METHODS = ("mlp", "transformer")

    def __init__(self,
                 *args,
                 encoder_layer: nn.Module,
                 middle_layer: nn.Module,
                 decoder_layer: nn.Module,
                 blocks: nn.Module,
                 norm: nn.Module,
                 action_dim: int = 3,
                 embed_dim: int = 256,
                 norm_layer: nn.LayerNorm = partial(nn.LayerNorm, eps=1e-6),
                 method: Literal["mlp", "transformer", "gpt", "mamba", "mamba2"] = "mlp",
                 device: str = "cpu",
                 seq_len: int = 64,
                 **kwargs
                 ):
        super().__init__()

        self.device = device
        self.encoder_layer = encoder_layer.to(device)
        self.middle_layer = middle_layer.to(device)
        self.decoder_layer = decoder_layer.to(device)
        self.blocks = blocks.to(device)
        self.norm = norm.to(device)
        self.method = method
        self.seq_len = seq_len

        # Embeds discrete action ids into the same width as the state latent.
        self.action_encoder = DiscreteEmbeddingLayer(
            num_max_tokens=action_dim,
            embed_dim=embed_dim,
        ).to(self.device)

        # Fuses the state latent with the action embedding (both embed_dim wide).
        self.atten_fusion = AttentionFusion(
            state_dim=embed_dim,
            action_dim=embed_dim,
            embed_dim=embed_dim,
        ).to(self.device)

        self.critic_blocks = nn.Sequential(
            nn.Linear(embed_dim, embed_dim),
            nn.Tanh(),
            nn.Linear(embed_dim, embed_dim),
        ).to(self.device)

        self.norm_decoder = norm_layer(embed_dim).to(self.device)

    def forward_encoder(self, features: "TensorDict"):
        """Encode ``features`` into a latent tensor using ``self.method``.

        Args:
            features: Input handed unchanged to the stem layer.

        Returns:
            The latent produced by the selected encoder variant.

        Raises:
            NotImplementedError: If ``self.method`` is neither ``"mlp"`` nor
                ``"transformer"``.
        """
        if self.method == "mlp":
            return self._encode_mlp(features)
        if self.method == "transformer":
            return self._encode_transformer(features)

        # Fail loudly instead of falling through with no result bound.
        raise NotImplementedError(
            f"Critic encoder method {self.method!r} is not implemented; "
            f"supported methods: {self._SUPPORTED_METHODS}"
        )

    def _encode_mlp(self, features):
        """Run the ``"mlp"`` encoder: stem, blocks, norm, then two reductions."""
        # Sub-modules were already placed on the target device in __init__,
        # so no per-call device move is needed here.
        x = self.encoder_layer['stem_layer'](features)

        x = self.blocks(x)  # extract features
        x = self.norm(x)  # normalize

        # Swap the last two axes so the reduce-seq layer acts on the axis that
        # was second-to-last (presumably the sequence axis -- TODO confirm).
        x = rearrange(x, '... d n -> ... n d')
        x = self.middle_layer['reduce_seq_layer'](x)

        # Merge the last two axes into a single flat axis.
        x = rearrange(x, '... n d -> ... (d n)')

        return self.middle_layer['reduce_embed_layer'](x)

    def _encode_transformer(self, features):
        """Run the ``"transformer"`` encoder and return the last position."""
        x = self.encoder_layer['stem_layer'](features)

        # Look up one embedding per position, created directly on x's device,
        # and add it to the stem output (relies on broadcasting).
        positions = torch.arange(self.seq_len, device=x.device)
        x = x + self.encoder_layer['indices'](positions)

        if x.dim() > 3:
            # Fold the two leading axes together so ``blocks`` sees one
            # leading axis, then restore them afterwards.
            b, n = x.shape[:2]
            x = rearrange(x, 'b n ... -> (b n) ...', b=b, n=n)
            x = self.blocks(x)  # extract features
            x = rearrange(x, '(b n) ... -> b n ...', b=b, n=n)
        else:
            x = self.blocks(x)

        x = self.norm(x)

        # Keep only the final entry along the second-to-last axis.
        return x[..., -1, :]

    def decoder(self, x: Tensor, a: Tensor = None):
        """Decode a latent, optionally conditioning on an action.

        Args:
            x: Latent tensor, e.g. the output of ``forward_encoder``.
            a: Optional action input for ``action_encoder``. When ``None`` the
                fusion step is skipped and ``x`` is decoded on its own.

        Returns:
            The output of ``decoder_layer``.
        """
        if a is not None:
            action_embedding = self.action_encoder(a)
            # Use only the final entry along the second-to-last axis.
            action_embedding = action_embedding[..., -1, :]

            x = self.atten_fusion(x, action_embedding)

        x = self.critic_blocks(x)
        x = self.norm_decoder(x)

        return self.decoder_layer(x)

    def forward(self, x: "TensorDict", a: Tensor = None):
        """Encode ``x`` and decode it (optionally with action ``a``)."""
        latent = self.forward_encoder(x)
        return self.decoder(latent, a)